{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b4268da6-98d2-4e23-a984-ce49c70d6a42",
   "metadata": {},
   "source": [
    "# LLM Streaming Client\n",
    "\n",
    "This notebook demonstrates how to stream responses from an LLM, processing each chunk of the response as soon as it arrives.\n",
    "\n",
    "### Triton Inference Server\n",
    "The LLM is deployed on [NVIDIA Triton Inference Server](https://developer.nvidia.com/triton-inference-server) and uses NVIDIA TensorRT-LLM (TRT-LLM), so it is optimized for low-latency, high-throughput inference.\n",
    "\n",
    "The Triton client communicates with the inference server hosting the LLM and is available as a [LangChain](https://github.com/langchain-ai/langchain-nvidia/tree/main/libs/trt) integration.\n",
    "\n",
    "### Streaming LLM Responses\n",
    "TRT-LLM on its own can substantially reduce LLM response latency, and streaming improves the user experience further. Instead of waiting for the entire response to be returned from the LLM, the client can process chunks of it as soon as they are available. This reduces the latency that the user perceives."
   ]
  },
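  {
   "cell_type": "markdown",
   "id": "stream-latency-sketch-md",
   "metadata": {},
   "source": [
    "The effect on perceived latency can be illustrated without a server. The cell below is a minimal sketch and not part of the deployed workflow: `fake_stream` is a hypothetical stand-in for a streaming client, `measure_stream` is a hypothetical helper, and the delay is an arbitrary value chosen for the demonstration. It compares the time until the first chunk arrives with the time until the full response is complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "stream-latency-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "def fake_stream(chunks, delay=0.05):\n",
    "    \"\"\"Hypothetical stand-in for a streaming client: yields chunks with a delay.\"\"\"\n",
    "    for chunk in chunks:\n",
    "        time.sleep(delay)  # simulated generation time per chunk\n",
    "        yield chunk\n",
    "\n",
    "\n",
    "def measure_stream(stream):\n",
    "    \"\"\"Consume a stream and return (time_to_first_chunk, total_time, text).\"\"\"\n",
    "    start = time.perf_counter()\n",
    "    first_chunk_time = None\n",
    "    pieces = []\n",
    "    for chunk in stream:\n",
    "        if first_chunk_time is None:\n",
    "            # The user sees output from this moment on.\n",
    "            first_chunk_time = time.perf_counter() - start\n",
    "        pieces.append(chunk)\n",
    "    total_time = time.perf_counter() - start\n",
    "    return first_chunk_time, total_time, \"\".join(pieces)\n",
    "\n",
    "\n",
    "first, total, text = measure_stream(fake_stream([\"The \", \"cheetah \", \"is \", \"fastest.\"]))\n",
    "print(f\"First chunk after {first:.2f}s, full response after {total:.2f}s: {text}\")"
   ]
  },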
  {
   "cell_type": "markdown",
   "id": "667181db-04d8-4c9d-b433-26c2a14d54e7",
   "metadata": {},
   "source": [
    "### Step 1: Structure the Query in a Prompt Template"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e206005-d153-49ce-8b54-41e425a7de17",
   "metadata": {},
   "source": [
    "A [**prompt template**](https://gpt-index.readthedocs.io/en/stable/api_reference/prompts.html) is a common paradigm in LLM development.\n",
    "\n",
    "It is a predefined set of instructions that is provided to the LLM and guides the output the model produces. A template can contain few-shot examples and other guidance, which makes it a quick way to shape the LLM's responses. `LLAMA_PROMPT_TEMPLATE` adapts the Llama 2 [prompt format](https://huggingface.co/blog/llama2#how-to-prompt-llama-2) so that the prompt is constructed from:\n",
    "- The system prompt\n",
    "- The context\n",
    "- The user's question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "42a2f2cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Llama 2 chat layout, adapted so that {context} fills the first assistant turn\n",
    "# and {question} opens a second [INST] block.\n",
    "LLAMA_PROMPT_TEMPLATE = (\n",
    "    \"<s>[INST] <<SYS>>\"\n",
    "    \"{system_prompt}\"\n",
    "    \"<</SYS>>\"\n",
    "    \"[/INST] {context} </s><s>[INST] {question} [/INST]\"\n",
    ")\n",
    "system_prompt = \"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are positive in nature.\"\n",
    "context = \"\"  # no supporting context is supplied in this example\n",
    "question = \"What is the fastest land animal?\"\n",
    "\n",
    "# Fill the three placeholders to produce the final prompt string.\n",
    "prompt = LLAMA_PROMPT_TEMPLATE.format(\n",
    "    system_prompt=system_prompt, context=context, question=question\n",
    ")"
   ]
  },
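  {
   "cell_type": "markdown",
   "id": "render-prompt-sketch-md",
   "metadata": {},
   "source": [
    "To see exactly what the model receives, it can help to render the template with placeholder values. The cell below is a minimal sketch: `EXAMPLE_TEMPLATE` and `render_prompt` are hypothetical names introduced here (the template text is copied from `LLAMA_PROMPT_TEMPLATE` so that the cell runs on its own), and the placeholder strings are illustrative only. With an empty `context`, as in the cell above, the first assistant turn contains only whitespace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "render-prompt-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy of the notebook's template so that this cell is self-contained.\n",
    "EXAMPLE_TEMPLATE = (\n",
    "    \"<s>[INST] <<SYS>>\"\n",
    "    \"{system_prompt}\"\n",
    "    \"<</SYS>>\"\n",
    "    \"[/INST] {context} </s><s>[INST] {question} [/INST]\"\n",
    ")\n",
    "\n",
    "\n",
    "def render_prompt(system_prompt, context, question):\n",
    "    \"\"\"Hypothetical helper: fill the three template fields.\"\"\"\n",
    "    return EXAMPLE_TEMPLATE.format(\n",
    "        system_prompt=system_prompt, context=context, question=question\n",
    "    )\n",
    "\n",
    "\n",
    "# Short placeholders make the position of each field easy to spot.\n",
    "rendered = render_prompt(\"SYSTEM\", \"CONTEXT\", \"QUESTION\")\n",
    "print(rendered)"
   ]
  },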
  {
   "cell_type": "markdown",
   "id": "9e975c7b-3c5e-4ba6-954a-0064e6a91245",
   "metadata": {},
   "source": [
    "### Step 2: Create the Triton Client"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2858c80-ba39-43be-9978-0f8be6e6c3dd",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-warning\">\n",
    "<b>WARNING!</b> Be sure to replace <code>triton_url</code> with the address and port that Triton is running on.\n",
    "</div>\n",
    "\n",
    "Use the address and port where Triton is available, for example `localhost:8001`.\n",
    "\n",
    "**If you are running this notebook as part of the AI workflow, you don't have to replace the URL.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "5670011e-f52b-4c16-be4d-8a782b622541",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_nvidia_trt.llms import TritonTensorRTLLM\n",
    "\n",
    "# Address and port of the Triton server; replace unless running inside the AI workflow.\n",
    "triton_url = \"llm:8001\"\n",
    "\n",
    "# Connection settings and generation parameters passed to the client.\n",
    "pload = {\n",
    "    \"tokens\": 300,\n",
    "    \"server_url\": triton_url,\n",
    "    \"model_name\": \"ensemble\",\n",
    "    \"temperature\": 1.0,\n",
    "    \"top_k\": 1,  # keep only the most likely token at each step\n",
    "    \"top_p\": 0,\n",
    "    \"beam_width\": 1,\n",
    "    \"repetition_penalty\": 1.0,\n",
    "    \"length_penalty\": 1.0,\n",
    "}\n",
    "client = TritonTensorRTLLM(**pload)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03676629-33d8-46d8-b5fc-557d526609b4",
   "metadata": {},
   "source": [
    "The following inputs to the LLM can be adjusted:\n",
    "- `tokens`: the maximum number of tokens (words or sub-words) to generate\n",
    "- `temperature`: typically in [0, 1]; higher values produce more diverse outputs\n",
    "- [`top_k`](https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p): sample from the k most likely next tokens at each step; lower values concentrate sampling on the highest-probability tokens (reduces variety)\n",
    "- [`top_p`](https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p): in [0, 1]; cumulative probability cutoff for token selection; lower values sample from a smaller nucleus (reduces variety)\n",
    "- `beam_width`: the number of beams used for beam search; 1 disables beam search\n",
    "- `repetition_penalty`: in [1, 2]; penalizes repeated tokens, and 1 applies no penalty\n",
    "- `length_penalty`: 1 means no penalty for the length of the generation"
   ]
  },
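  {
   "cell_type": "markdown",
   "id": "top-k-top-p-sketch-md",
   "metadata": {},
   "source": [
    "The interaction between `top_k` and `top_p` can be sketched on a toy distribution. The function below illustrates the filtering described above and is not the code that TRT-LLM runs: `filter_top_k_top_p`, the convention that `top_k=0` means no limit, and the example probabilities are all assumptions made for this sketch. With `top_k=1`, as configured in this notebook, only the most likely token survives, so generation is effectively greedy regardless of the temperature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "top-k-top-p-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_top_k_top_p(probs, top_k=0, top_p=1.0):\n",
    "    \"\"\"Apply top-k, then top-p, filtering and renormalize the survivors.\n",
    "\n",
    "    probs maps each token to its probability. In this sketch, top_k=0\n",
    "    means that no top-k limit is applied.\n",
    "    \"\"\"\n",
    "    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)\n",
    "    if top_k > 0:\n",
    "        ranked = ranked[:top_k]  # keep the k most likely tokens\n",
    "    kept, cumulative = [], 0.0\n",
    "    for token, p in ranked:\n",
    "        kept.append((token, p))\n",
    "        cumulative += p\n",
    "        if cumulative >= top_p:  # smallest nucleus that reaches the cutoff\n",
    "            break\n",
    "    total = sum(p for _, p in kept)\n",
    "    return {token: p / total for token, p in kept}\n",
    "\n",
    "\n",
    "# Made-up next-token probabilities, for illustration only.\n",
    "toy = {\"cheetah\": 0.5, \"pronghorn\": 0.25, \"horse\": 0.125, \"lion\": 0.125}\n",
    "greedy = filter_top_k_top_p(toy, top_k=1)\n",
    "nucleus = filter_top_k_top_p(toy, top_k=3, top_p=0.7)\n",
    "print(greedy)\n",
    "print(nucleus)"
   ]
  },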
  {
   "cell_type": "markdown",
   "id": "c526b20b-258a-4eb7-87e6-5430d57e32ea",
   "metadata": {},
   "source": [
    "### Step 3: Load the Model and Stream Responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "274e4164-7460-4471-8625-90562237cf11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "start_time = time.perf_counter()\n",
    "tokens_generated = 0\n",
    "\n",
    "# Print each chunk as soon as it arrives instead of waiting for the full response.\n",
    "for val in client.stream(prompt):\n",
    "    tokens_generated += 1  # counts streamed chunks; assumes one token per chunk\n",
    "    print(val, end=\"\", flush=True)\n",
    "\n",
    "total_time = time.perf_counter() - start_time\n",
    "print(f\"\\n--- Generated {tokens_generated} tokens in {total_time:.2f} seconds ---\")\n",
    "print(f\"--- {tokens_generated / total_time:.2f} tokens/sec ---\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
